classdef TradingStrategy < matlab.mixin.Heterogeneous
    %TRADINGSTRATEGY Interface shared by all trading strategies.
    %   All trading strategies inherit from this class. It provides the
    %   functionality common to every strategy: performance summary
    %   statistics, risk (forecast-vs-realized) summary statistics, mean
    %   estimation with optional shrinkage, mean-variance weights and
    %   per-period bookkeeping of expected and realized returns.
    %
    %   This is a value class (it does not derive from handle), so methods
    %   that modify the object return it and callers must keep the result.

    properties
        weights                     % portfolio weights; summary reads one column per period
        strategyReturns             % realized return per period (written by executeCommonElements)
        strategyExpectedReturns     % ex-ante expected return per period
        strategyExpectedVariance    % ex-ante expected variance per period
        turnover                    % per-period turnover (averaged by summary)
        negativeWealth
        shrink=0                    % 1 => James-Stein-type shrinkage in getMeanReturns
        completePort=0              % 1 => weights scaled by 1/riskAversion, else normalized by |sum|
        deltaSeries
        transactionCosts=NaN;
        ids
        rf
        activeShare                 % its summary statistic is currently disabled
        Means
        Covs
        storeMoments=false
        TransactionCostOptimize=false;
        riskAversion = 3;           % divisor used when completePort == 1
        externalMoments=false;
        externalMeans;
        externalCovs;
        % Properties without a comment are not read in this file;
        % presumably they are used by subclasses (confirm).
    end

    methods
        % Declared here; the implementation is not in this file
        % (presumably provided by subclasses or a method file - confirm).
        strategyCalculate(obj)

        function sumStats=summary(obj,start_loc)
            %SUMMARY Performance and portfolio-composition statistics.
            %   sumStats = summary(obj, start_loc) evaluates the strategy
            %   from period start_loc onward and returns a 1x22 row vector.
            %   Only these columns are filled; all others stay zero:
            %     1  mean return * 1200 (annualized, percent)
            %     2  std of returns * sqrt(12) * 100 (annualized, percent)
            %     3  column 1 / column 2
            %     4  reserved (active-share statistic currently disabled)
            %     5  mean turnover in percent, NaNs ignored
            %     6  mean over periods of the largest weight, percent
            %     7  mean over periods of the smallest weight, percent
            %     8  mean over periods of the cross-sectional weight std, percent
            %     9  std over periods of the cross-sectional weight std, percent
            %    10  number of non-NaN returns used
            %    11  mean over periods of the summed negative weights, percent
            %   The x12 scaling assumes monthly returns in decimal form.
            %   The weight window ends at start_loc + (#non-NaN returns) - 1,
            %   which presumes NaN returns only occur after the last valid
            %   period - confirm with callers.

            sumStats=zeros(1,22);

            retUsed = obj.strategyReturns(start_loc:end);
            retUsed = retUsed(~isnan(retUsed));
            % numel (not size(...,1)) so the window is correct for both
            % row- and column-vector return series.
            end_loc = start_loc + numel(retUsed) - 1;

            sumStats(1,1)=mean(retUsed)*1200;
            sumStats(1,2)=std(retUsed)*sqrt(12)*100;
            sumStats(1,3)=sumStats(1,1)/sumStats(1,2);

            %sumStats(1,4) = mean(obj.activeShare(start_loc:end_loc))*100;
            sumStats(1,5)=nanmean(obj.turnover(start_loc:end))*100;

            % Reduce explicitly along dimension 1 (assets) so a single-asset
            % weight matrix still yields one value per period.
            wgt = obj.weights(:,start_loc:end_loc);
            crossStd = std(wgt,0,1);
            sumStats(1,6)=mean(max(wgt,[],1))*100;
            sumStats(1,7)=mean(min(wgt,[],1))*100;
            sumStats(1,8)=mean(crossStd)*100;
            sumStats(1,9)=std(crossStd)*100;
            sumStats(1,10)=numel(retUsed);
            sumStats(1,11)=mean(sum(wgt.*(wgt<0),1))*100;
        end

        function [tab_risk, tab_pval]= riskSummary(obj,start_loc)
            %RISKSUMMARY Compare ex-ante forecasts with realized returns.
            %   [tab_risk, tab_pval] = riskSummary(obj, start_loc) uses the
            %   periods start_loc..length(obj.strategyExpectedVariance) and
            %   drops periods whose realized return is NaN. Both outputs are
            %   1x9; entries not listed below stay NaN.
            %     tab_risk(1)   mean expected return * 12
            %     tab_risk(2)   mean realized return * 12
            %     tab_pval(2)   two-sided normal p-value that the mean
            %                   forecast error (realized - expected) is zero
            %     tab_risk(3)   root-mean-squared forecast error (not annualized)
            %     tab_risk(4)   mean of sqrt(12 * expected variance)
            %     tab_risk(5)   sqrt(12 * sample variance of realized returns)
            %     tab_risk(6:9) fraction of standardized errors
            %                   d = (realized - expected)/sqrt(expected variance)
            %                   below norminv(0.01), below norminv(0.05),
            %                   above norminv(0.95), above norminv(0.99)
            %     tab_pval(6:9) hitrate_pvalue (defined outside this file)
            %                   for each rate against its nominal tail
            %                   probability: 0.01, 0.05, 0.05, 0.01.
            tab_risk=NaN(1,9);
            tab_pval=NaN(1,9);

            periods = start_loc:length(obj.strategyExpectedVariance);
            expVar  = obj.strategyExpectedVariance(periods);
            realRet = obj.strategyReturns(periods);
            expRet  = obj.strategyExpectedReturns(periods);

            valid   = ~isnan(realRet);
            expVar  = expVar(valid);
            realRet = realRet(valid);
            expRet  = expRet(valid);

            err = realRet - expRet;

            tab_risk(1)=mean(expRet)*12;
            tab_risk(2)=mean(realRet)*12;
            tab_risk(3)=sqrt(mean(err.^2));
            tab_risk(4)=mean(sqrt(12*expVar));
            tab_risk(5)=sqrt(12*var(realRet));

            z_stat = -abs(mean(err))/(std(err)/sqrt(length(realRet)));
            tab_pval(2) = normcdf(z_stat)*2;

            d = err./sqrt(expVar);
            d = d(~isnan(d));
            nd = length(d);

            % Tail tests. Upper-tail events (d above the q-quantile) have
            % nominal probability 1-q, so they are tested against 0.05 and
            % 0.01, not against the quantile levels 0.95 and 0.99.
            levels    = [0.01 0.05 0.95 0.99];
            upperTail = [false false true true];
            nominal   = [0.01 0.05 0.05 0.01];
            for k = 1:numel(levels)
                cutoff = norminv(levels(k));
                if upperTail(k)
                    rate = sum(d > cutoff)/nd;
                else
                    rate = sum(d < cutoff)/nd;
                end
                tab_risk(5+k) = rate;
                tab_pval(5+k) = hitrate_pvalue(rate, nominal(k), nd);
            end
        end

        function [tab_out]= performanceRiskSummary(obj,start_loc)
            %PERFORMANCERISKSUMMARY Return/volatility ratio and volatility ratio.
            %   tab_out(1,1): (mean return * 1200) / (std * sqrt(12) * 100)
            %                 over the non-NaN returns from start_loc onward;
            %                 the same quantity as column 3 of summary.
            %   tab_out(1,2): realized annualized volatility divided by the
            %                 mean expected annualized volatility, i.e.
            %                 tab_risk(5)/tab_risk(4) of riskSummary.
            tab_out=NaN(1,2);

            retUsed = obj.strategyReturns(start_loc:end);
            retUsed = retUsed(~isnan(retUsed));

            avg=mean(retUsed)*1200;
            stdret=std(retUsed)*sqrt(12)*100;
            tab_out(1,1)=avg/stdret;

            periods = start_loc:length(obj.strategyExpectedVariance);
            expVar  = obj.strategyExpectedVariance(periods);
            realRet = obj.strategyReturns(periods);
            valid   = ~isnan(realRet);

            evar=mean(sqrt(12*expVar(valid)));
            rvar=sqrt(12*var(realRet(valid)));

            tab_out(1,2)=rvar/evar;
        end

        function basisMeans = getMeanReturns(obj,r_past)
            %GETMEANRETURNS Column means of r_past, optionally shrunk.
            %   basisMeans = getMeanReturns(obj, r_past) returns the column
            %   means of r_past as a column vector. When obj.shrink == 1 the
            %   means are shrunk toward their (NaN-ignoring) average using the
            %   factor
            %       max(0, 1 - (N-3)*(s2/T) / (N*var(means)))
            %   where N = number of columns, T = number of rows, and s2 is the
            %   variance of all entries of r_past pooled together.
            %   Presumably rows are periods and columns assets - confirm.

            % Explicit dim 1 so a single-row history still gives one mean
            % per column rather than a scalar.
            rawMeans=mean(r_past,1)';

            if obj.shrink ~= 1
                basisMeans = rawMeans;
                return;
            end
            N = size(rawMeans,1);
            T = size(r_past,1);

            %Shrinkage using James-Stein Estimator
            grand_mean = nanmean(rawMeans);
            sigma2_hist = var(r_past(:));
            shrinkJS = max(0,1 - ((N-3)*((sigma2_hist)/T)/(N*var(rawMeans))));
            basisMeans = grand_mean + shrinkJS*(rawMeans - grand_mean);
        end

        function weightsMV = setMVStrategyParams(obj,basisCov,basisMeans)
            %SETMVSTRATEGYPARAMS Mean-variance weights basisCov \ basisMeans,
            %   rescaled by adjustWeightsForRisklessRateSettings.
            weightsMV=basisCov\basisMeans;
            weightsMV = adjustWeightsForRisklessRateSettings(obj,weightsMV);
        end

        function weightsMV = adjustWeightsForRisklessRateSettings(obj,weightsMV)
            %ADJUSTWEIGHTSFORRISKLESSRATESETTINGS Rescale raw weights.
            %   completePort == 1: divide by obj.riskAversion (the remainder
            %   is presumably held in the riskless asset - confirm).
            %   Otherwise: divide by |sum(weights)|, so the weights sum to
            %   +1 or -1 (Inf/NaN if the raw weights sum to zero).
            if obj.completePort == 1
                weightsMV=weightsMV/obj.riskAversion;
            else
                weightsMV=weightsMV/abs(sum(weightsMV));
            end
        end

        function obj = executeCommonElements(obj,t,weightsMV,basisCov,basisMeans, r_plus1,r_curr) %#ok<INUSD>
            %EXECUTECOMMONELEMENTS Record period-t expected and realized results.
            %   Stores the expected variance w'*basisCov*w and the expected
            %   return w'*basisMeans at index t, but only when basisMeans
            %   contains no NaN. Always stores the realized return
            %   w'*r_plus1 at index t. r_curr is accepted but not used.
            %   Returns the updated object (value class).
            if ~isnan(sum(basisMeans))
                obj.strategyExpectedVariance(t)=weightsMV'*basisCov*weightsMV;
                obj.strategyExpectedReturns(t)=weightsMV'*basisMeans;
            end

            obj.strategyReturns(t)=weightsMV'*r_plus1;
        end

    end


end
